--- ../../swvox/swvox/swvox.py	2024-02-14 05:16:57.541407536 -0500
+++ ../google/swvox/swvox/swvox.py	2024-02-13 10:46:24.576645015 -0500
@@ -30,7 +30,7 @@
 import numpy as np
 import math
 from torch import nn, autograd
-from svox.helpers import N3TreeView, DataFormat, LocalIndex, _get_c_extension
+from swvox.helpers import N3TreeView, DataFormat, LocalIndex, WaveletType, _get_c_extension
 from warnings import warn
 
 _C = _get_c_extension()
@@ -53,6 +53,25 @@
                          grad_out.contiguous()), None, None
         return None, None, None
 
+class _QueryVerticalFunctionFullPath(autograd.Function):
+    """Autograd bridge to the C extension's ``query_vertical_path`` kernels."""
+    @staticmethod
+    def forward(ctx, data, tree_spec, indices):
+        # `data` is not read here; it is an input so backward() can return a gradient for it.
+        out, node_ids = _C.query_vertical_path(tree_spec, indices)
+        ctx.mark_non_differentiable(node_ids)
+        ctx.tree_spec = tree_spec
+        ctx.save_for_backward(indices)
+        return out, node_ids
+
+    @staticmethod
+    def backward(ctx, grad_out, dummy):
+        if not ctx.needs_input_grad[0]:
+            return None, None, None
+        grad_data = _C.query_vertical_path_backward(ctx.tree_spec, ctx.saved_tensors[0], grad_out.contiguous())
+        return grad_data, None, None
+
+
 
 class N3Tree(nn.Module):
     """
@@ -67,14 +86,17 @@
         or :code:`expand(), shrink()` is used,
         please re-make any optimizers
     """
-    def __init__(self, N=2, data_dim=None, depth_limit=10,
-            init_reserve=1, init_refine=0, geom_resize_fact=1.0,
-            radius=0.5, center=[0.5, 0.5, 0.5],
-            data_format="RGBA",
-            extra_data=None,
-            device="cpu",
-            dtype=torch.float32,
-            map_location=None):
+    def __init__(self, wavelet_type, lowpass_depth, eval_wavelet_integral=False, 
+                 N=2, data_dim=None, depth_limit=10,
+                 init_reserve=1, init_refine=0, geom_resize_fact=1.0,
+                 radius=0.5, center=[0.5, 0.5, 0.5],
+                 data_format="RGBA",
+                 extra_data=None,
+                 device="cpu",
+                 dtype=torch.float32,
+                 map_location=None,
+                 linear_color=False,
+                 wavelet_sigma=False):
         """
         Construct N^3 Tree
 
@@ -104,13 +126,18 @@
         assert N >= 2
         assert depth_limit >= 0
         self.N : int = N
-
+        assert lowpass_depth == -1 or lowpass_depth <= init_refine, "If using lowpass, lowpass_depth must be <= init_refine, so that the lowpass gets initialized from the begining."
         if map_location is not None:
             warn('map_location has been renamed to device and may be removed')
             device = map_location
         assert dtype == torch.float32 or dtype == torch.float64, 'Unsupported dtype'
 
         self.data_format = DataFormat(data_format) if data_format is not None else None
+        self.wavelet_type = WaveletType(wavelet_type)
+        self.eval_wavelet_integral = eval_wavelet_integral
+        self.linear_color = linear_color
+        self.lowpass_depth = lowpass_depth
+        self.wavelet_sigma = wavelet_sigma
         self.data_dim : int = data_dim
         self._maybe_auto_data_dim()
         del data_dim
@@ -126,6 +153,9 @@
         self.register_buffer("parent_depth", torch.zeros(
             init_reserve, 2, dtype=torch.int32, device=device))
 
+        # for now we just set to depth_limit. We could do better by dynamically updating it on each refine.
+        self.register_buffer("max_depth", torch.tensor(depth_limit, dtype=torch.long, device=device))
+
         self.register_buffer("_n_internal", torch.tensor(1, device=device))
         self.register_buffer("_n_free", torch.tensor(0, device=device))
 
@@ -154,6 +184,11 @@
 
         self.refine(repeats=init_refine)
 
+    def get_sigma_basis(self):
+        """Return the sigma channel count: ``n_wavelets`` when ``wavelet_sigma`` is set, else 1."""
+        if not self.wavelet_sigma:
+            return 1
+        return self.wavelet_type.n_wavelets
 
     # Main accesors
     def set(self, indices, values, cuda=True):
@@ -205,7 +240,7 @@
         else:
             _C.assign_vertical(self._spec(), indices, values)
 
-    def forward(self, indices, cuda=True, want_node_ids=False, world=True):
+    def forward(self, indices, cuda=True, want_node_ids=False, world=True, full_path=False):
         """
         Get tree values. Differentiable.
 
@@ -214,13 +249,30 @@
                      uses only PyTorch version.
         :param want_node_ids: if true, returns node ID for each query.
         :param world: use world space instead of :code:`[0,1]^3`, default True
+        
+        :param full_path: if true, return the full path from root to leaf; otherwise return only the leaf, default False
 
         :return: :code:`(Q, data_dim), [(Q)]`
 
         """
         assert not indices.requires_grad  # Grad wrt indices not supported
         assert len(indices.shape) == 2
+        
+        if full_path:
+            return self._forward_full_path(indices, cuda, want_node_ids, world)
+        else:
+            return self._forward_leaf(indices, cuda, want_node_ids, world)
+
+    def _forward_full_path(self, indices, cuda=True, want_node_ids=False, world=True):
+        """Full-path query through the C extension; raises when the CUDA path is unavailable."""
+        cuda_ok = cuda and _C is not None and self.data.is_cuda
+        if not cuda_ok:
+            raise Exception("Only cuda implementation for now!")
+        result, node_ids = _QueryVerticalFunctionFullPath.apply(self.data, self._spec(world), indices)
+        return (result, node_ids) if want_node_ids else result
+            
 
+    def _forward_leaf(self, indices, cuda=True, want_node_ids=False, world=True):
         if not cuda or _C is None or not self.data.is_cuda:
             if not want_node_ids:
                 warn("Using slow query")
@@ -312,7 +364,10 @@
                 depth_limit=self.depth_limit,
                 geom_resize_fact=self.geom_resize_fact,
                 dtype=dtype,
-                device=device)
+                device=device,
+                wavelet_type=self.wavelet_type.wavelet_name,
+                lowpass_depth=self.lowpass_depth,
+                init_refine=self.lowpass_depth)
         def copy_to_device(x):
             return torch.empty(x.shape, dtype=x.dtype, device=device).copy_(x)
         t2.invradius = copy_to_device(self.invradius)
@@ -633,7 +688,7 @@
 
 
     # Leaf refinement & memory management methods
-    def refine(self, repeats=1, sel=None):
+    def refine(self, repeats=1, sel=None, copy_data_from_parent=True, return_data_positions=False):
         """
         Refine each selected leaf node, respecting depth_limit.
 
@@ -652,6 +707,10 @@
             memory will be wasted. We do not dedup here for efficiency reasons.
 
         """
+        assert type(copy_data_from_parent) == bool or copy_data_from_parent in ['all', 'none', 'sigma', 'zero'] or \
+            (type(copy_data_from_parent) is torch.Tensor and len(copy_data_from_parent.shape) == 1)
+        if type(copy_data_from_parent) == bool and copy_data_from_parent:
+            copy_data_from_parent = 'all'
         if self._lock_tree_structure:
             raise RuntimeError("Tree locked")
         with torch.no_grad():
@@ -664,12 +723,15 @@
                 depths = self.parent_depth[sel[0], 1]
                 # Filter by depth & leaves
                 good_mask = (depths < self.depth_limit) & (self.child[sel] == 0)
-                sel = [t[good_mask] for t in sel]
+                sel = [t[good_mask.cpu()] for t in sel]
                 leaf_node =  torch.stack(sel, dim=-1).to(device=self.data.device)
                 num_nc = len(sel[0])
                 if num_nc == 0:
                     # Nothing to do
-                    return False
+                    if return_data_positions:
+                        return False, (-1,-1)
+                    else:
+                        return False
                 new_filled = filled + num_nc
 
                 cap_needed = new_filled - self.capacity
@@ -682,11 +744,23 @@
 
                 self.child[filled:new_filled] = 0
                 self.child[sel] = new_idxs - leaf_node[:, 0].to(torch.int32)
-                self.data.data[filled:new_filled] = self.data.data[
-                        sel][:, None, None, None]
+                # Tensor selectors are matched by type first: truth-testing a tensor raises
+                # for more than one element and is False for tensor([0]).
+                if isinstance(copy_data_from_parent, torch.Tensor):
+                    # just copy selected dimensions
+                    self.data.data[filled:new_filled,..., copy_data_from_parent] = self.data.data[sel][:,copy_data_from_parent][:, None, None, None]
+                elif copy_data_from_parent == 'sigma':
+                    self.data.data[filled:new_filled,..., -self.get_sigma_basis():] = self.data.data[sel][:,-self.get_sigma_basis():][:, None, None, None]
+                elif copy_data_from_parent == 'all':
+                    self.data.data[filled:new_filled] = self.data.data[sel][:, None, None, None]
+                elif copy_data_from_parent is False or copy_data_from_parent in ('none', 'zero'):
+                    # init with 0's; 'none' is accepted by the assert at the top of refine(),
+                    # so it must not fall through to the "not implemented" error
+                    self.data.data[filled:new_filled] = 0
+                else:
+                    raise Exception("copy_data_from_parent strategy '{}' not implemented".format(copy_data_from_parent))
                 self.parent_depth[filled:new_filled, 0] = self._pack_index(leaf_node)  # parent
-                self.parent_depth[filled:new_filled, 1] = self.parent_depth[
-                        leaf_node[:, 0], 1] + 1  # depth
+                self.parent_depth[filled:new_filled, 1] = self.parent_depth[leaf_node[:, 0], 1] + 1  # depth
 
                 if repeat_id < repeats - 1:
                     # Infer new selector
@@ -702,9 +776,12 @@
                 self._n_internal += num_nc
         if repeats > 0:
             self._invalidate()
-        return resized
+        if not return_data_positions:
+            return resized
+        else:
+            return resized, (filled,new_filled)
 
-    def _refine_at(self, intnode_idx, xyzi):
+    def _refine_at(self, intnode_idx, xyzi, copy_data_from_parent=True):
         """
         Advanced: refine specific leaf node. Mostly for testing purposes.
 
@@ -713,6 +790,8 @@
                     in xyz orde rto identify leaf within internal node
 
         """
+        assert type(copy_data_from_parent) == bool or copy_data_from_parent in ['all', 'none'] or \
+            (type(copy_data_from_parent) is torch.Tensor and len(copy_data_from_parent.shape) == 1)
         if self._lock_tree_structure:
             raise RuntimeError("Tree locked")
         assert min(xyzi) >= 0 and max(xyzi) < self.N
@@ -736,7 +815,13 @@
         self.parent_depth[filled, 0] = self._pack_index(torch.tensor(
             [[intnode_idx, xi, yi, zi]], dtype=torch.int32))[0]
         self.parent_depth[filled, 1] = depth
-        self.data.data[filled, :, :, :] = self.data.data[intnode_idx, xi, yi, zi]
+        if isinstance(copy_data_from_parent, torch.Tensor):
+            # just copy selected dimensions (tensor case first, so it is never compared to a string)
+            self.data.data[filled,..., copy_data_from_parent] = self.data.data[intnode_idx, xi, yi, zi,copy_data_from_parent]
+        elif copy_data_from_parent is not False and copy_data_from_parent != 'none':
+            # `and`, not `or`: the previous `or` test was always true, so False/'none' still copied
+            self.data.data[filled, :, :, :] = self.data.data[intnode_idx, xi, yi, zi]
+
         self.data.data[intnode_idx, xi, yi, zi] = 0
         self._n_internal += 1
         self._invalidate()
@@ -764,7 +849,7 @@
             free = self.parent_depth[:n_int, 0] == -1
             csum = torch.cumsum(free, dim=0)
 
-            remain_ids = torch.arange(n_int, dtype=torch.long)[~free]
+            remain_ids = torch.arange(n_int, dtype=torch.long).to(free.device)[~free]
             remain_parents = (*self._unpack_index(
                 self.parent_depth[remain_ids, 0]).long().T,)
 
@@ -832,7 +917,7 @@
         return WeightAccumulator(self, op)
 
     # Persistence
-    def save(self, path, shrink=True, compress=True):
+    def save(self, path, shrink=False, compress=True):
         """
         Save to from npz file
 
@@ -855,6 +940,12 @@
             "geom_resize_fact": self.geom_resize_fact,
             "data": self.data.data.half().cpu().numpy()  # save CPU Memory
         }
+        data['wavelet_type'] = self.wavelet_type.wavelet_name
+        data['lowpass_depth'] = self.lowpass_depth
+        data['eval_wavelet_integral'] = self.eval_wavelet_integral
+        data['linear_color'] = self.linear_color
+        data['wavelet_sigma'] = self.wavelet_sigma
+
         if self.data_format is not None:
             data["data_format"] = repr(self.data_format)
         if self.extra_data is not None:
@@ -897,7 +988,7 @@
         self[LocalIndex(idx)] = grid.reshape(-1, self.data_dim)
 
     @classmethod
-    def load(cls, path, device='cpu', dtype=torch.float32, map_location=None):
+    def load(cls, path, device='cpu', dtype=torch.float32, map_location=None, wavelet_type=None, lowpass_depth=None):
         """
         Load from npz file
 
@@ -907,13 +998,18 @@
         :param map_location: str DEPRECATED old name for device
 
         """
+        
         if map_location is not None:
             warn('map_location has been renamed to device and may be removed')
             device = map_location
         assert dtype == torch.float32 or dtype == torch.float64, 'Unsupported dtype'
-        tree = cls(dtype=dtype, device=device)
         z = np.load(path)
-        tree.data_dim = int(z["data_dim"])
+        if wavelet_type is None:
+            wavelet_type = z['wavelet_type'].item()
+        if lowpass_depth is None:
+            lowpass_depth = z['lowpass_depth'].item()
+        data_format = z['data_format'].item() if 'data_format' in z.files else None
+        tree = cls(wavelet_type=wavelet_type, lowpass_depth=lowpass_depth, init_refine=lowpass_depth, dtype=dtype, device=device, data_format=data_format)
         tree.child = torch.from_numpy(z["child"]).to(device)
         tree.N = tree.child.shape[-1]
         tree.parent_depth = torch.from_numpy(z["parent_depth"]).to(device)
@@ -926,7 +1022,7 @@
         tree.offset = torch.from_numpy(z["offset"].astype(np.float32)).to(device)
         tree.depth_limit = int(z["depth_limit"])
         tree.geom_resize_fact = float(z["geom_resize_fact"])
-        tree.data.data = torch.from_numpy(z["data"].astype(np.float32)).to(device)
+        tree.data = nn.Parameter(torch.from_numpy(z["data"].astype(np.float32)).to(device))
         if 'n_free' in z.files:
             tree._n_free.fill_(z["n_free"].item())
         else:
@@ -935,11 +1031,17 @@
                 'data_format' in z.files else None
         tree.extra_data = torch.from_numpy(z['extra_data']).to(device) if \
                           'extra_data' in z.files else None
+        tree._maybe_auto_data_dim()
+        if tree.data_dim is None:
+            tree.data_dim = int(z["data_dim"])
+        if tree.data_dim != tree.data.shape[-1]:
+            print("The loaded data does not correspond with the data_dim! You may be loading a previous version tree, make sure you know what you are doing!")
+
         return tree
 
     # Magic
     def __repr__(self):
-        return (f"svox.N3Tree(N={self.N}, data_dim={self.data_dim}, " +
+        return (f"swvox.N3Tree(N={self.N}, data_dim={self.data_dim}, " +
                 f"depth_limit={self.depth_limit}, " +
                 f"capacity:{self.n_internal - self._n_free.item()}/{self.capacity}, " +
                 f"data_format:{self.data_format or 'RGBA'})");
@@ -1095,6 +1197,7 @@
                   [0.0, 0.0, 0.0], dtype=self.data.dtype, device=self.data.device)
         tree_spec.scaling = self.invradius if world else torch.tensor(
                   [1.0, 1.0, 1.0], dtype=self.data.dtype, device=self.data.device)
+        tree_spec.max_depth = torch.tensor([self.max_depth], dtype=torch.int, device=self.data.device)
         if hasattr(self, '_weight_accum'):
             tree_spec._weight_accum = self._weight_accum if \
                     self._weight_accum is not None else torch.empty(
@@ -1103,14 +1206,17 @@
         return tree_spec
 
     def _maybe_auto_data_dim(self):
-        if self.data_format is not None and self.data_format.data_dim is not None:
+        if self.data_format is not None:
+            assert self.data_format.basis_dim is not None, "data_format does not have data_dim, this is a legacy N3Tree!"
+            # Computed only after the guards above; evaluating it first raised AttributeError/TypeError before they could run.
+            implied_data_dim = 3 * self.data_format.basis_dim * self.wavelet_type.n_wavelets + (self.wavelet_type.n_wavelets if self.wavelet_sigma else 1)
             if self.data_dim is None:
-                self.data_dim = self.data_format.data_dim
+                self.data_dim = implied_data_dim
             else:
-                assert self.data_format.data_dim == self.data_dim, "data_dim invalid for given data format"
+                assert implied_data_dim == self.data_dim, "data_dim invalid for given data format"
         elif self.data_dim is None:
             # Legacy default
-            self.data_dim = 4
+            raise Exception("Data dim incorrect when instantiating tree!")
 
 
 # Redirect functions to N3TreeView so you can do tree.depths instead of tree[:].depths
